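# Preprocess MovieLens ratings into pairwise conditional probabilities, then train and evaluate BubbleEmbedBase embeddings on them.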
import pandas as pd
import os
import sys
import json
import collections
import random
import math
import argparse
import time
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from utils import print_local_time
from model_base import BubbleEmbedBase
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer
import scipy

class DataTrain(Dataset):
    """Training pairs of movies together with their conditional probability.

    Each CSV row holds ``movieA``, ``movieB`` and ``conditionalProb``. Movie ids
    are looked up (as strings) in a JSON mapping of id -> context text, and the
    text is tokenized to a fixed length of 30 tokens.
    """

    def __init__(self, args, tokenizer, data_path, desc_file):
        self.args = args
        self.data, self.desc = self._load_data(data_path, desc_file)
        self.tokenizer = tokenizer

    def _load_data(self, data_path, desc_file):
        """Return ``(samples DataFrame, dict of movie id string -> context text)``."""
        data = pd.read_csv(data_path, header=0, index_col=None)
        # Use a context manager so the JSON file handle is closed deterministically.
        with open(desc_file, 'r', encoding='utf-8') as f:
            desc = json.load(f)
        return data, desc

    def __len__(self):
        return len(self.data)

    def process_encode(self, encode):
        """Drop the leading batch dimension of each tokenizer tensor.

        The tokenizer is called on a single string with ``return_tensors='pt'``,
        so each tensor carries a leading dimension of 1. Tensors are moved to the
        GPU when ``args.cuda`` is set. ``encode`` is modified in place and returned.
        """
        input_ids = encode['input_ids'].squeeze(0)
        attention_mask = encode['attention_mask'].squeeze(0)
        token_type_ids = encode['token_type_ids'].squeeze(0)
        if self.args.cuda:
            input_ids = input_ids.cuda()
            attention_mask = attention_mask.cuda()
            token_type_ids = token_type_ids.cuda()
        encode["input_ids"] = input_ids
        encode["attention_mask"] = attention_mask
        encode["token_type_ids"] = token_type_ids
        return encode

    def generate_train_instance_id(self, idx):
        """Tokenize both movies of row ``idx``; return ``(encA, encB, prob)``."""
        row = self.data.iloc[idx]
        # JSON object keys are strings, so normalise the CSV ids (e.g. 1.0 -> "1").
        movieA = str(int(row['movieA']))
        movieB = str(int(row['movieB']))
        contextA = self.desc[movieA]
        contextB = self.desc[movieB]
        encodeA = self.tokenizer(contextA, padding="max_length", truncation=True, return_tensors='pt', max_length=30)
        processedA = self.process_encode(encodeA)
        encodeB = self.tokenizer(contextB, padding="max_length", truncation=True, return_tensors='pt', max_length=30)
        processedB = self.process_encode(encodeB)
        prob = row['conditionalProb']
        return processedA, processedB, prob

    def __getitem__(self, idx):
        """Return ``(encode_A, encode_B, prob)`` with ``prob`` as a float tensor."""
        encode_A, encode_B, prob = self.generate_train_instance_id(idx)
        prob = torch.tensor(prob, dtype=torch.float)
        if self.args.cuda:
            prob = prob.cuda()
        return encode_A, encode_B, prob
    
class DataTest(Dataset):
    """Evaluation pairs of movies with their conditional probability and ids.

    Same CSV/JSON inputs as the training dataset. Each item additionally
    carries the two movie id strings, so predictions can be grouped per movie.
    """

    def __init__(self, args, tokenizer, data_path, desc_file):
        self.args = args
        self.data, self.desc = self._load_data(data_path, desc_file)
        self.tokenizer = tokenizer

    def _load_data(self, data_path, desc_file):
        """Return ``(samples DataFrame, dict of movie id string -> context text)``."""
        data = pd.read_csv(data_path, header=0, index_col=None)
        # Use a context manager so the JSON file handle is closed deterministically.
        with open(desc_file, 'r', encoding='utf-8') as f:
            desc = json.load(f)
        return data, desc

    def __len__(self):
        return len(self.data)

    def process_encode(self, encode):
        """Drop the leading batch dimension of each tokenizer tensor.

        Tensors are moved to the GPU when ``args.cuda`` is set. ``encode`` is
        modified in place and returned.
        """
        input_ids = encode['input_ids'].squeeze(0)
        attention_mask = encode['attention_mask'].squeeze(0)
        token_type_ids = encode['token_type_ids'].squeeze(0)
        if self.args.cuda:
            input_ids = input_ids.cuda()
            attention_mask = attention_mask.cuda()
            token_type_ids = token_type_ids.cuda()
        encode["input_ids"] = input_ids
        encode["attention_mask"] = attention_mask
        encode["token_type_ids"] = token_type_ids
        return encode

    def generate_train_instance_id(self, idx):
        """Tokenize both movies of row ``idx``.

        Returns:
            ``(encA, encB, prob, movieA, movieB)``, where the movie ids are strings.
        """
        row = self.data.iloc[idx]
        # JSON object keys are strings, so normalise the CSV ids (e.g. 1.0 -> "1").
        movieA = str(int(row['movieA']))
        movieB = str(int(row['movieB']))
        contextA = self.desc[movieA]
        contextB = self.desc[movieB]
        encodeA = self.tokenizer(contextA, padding="max_length", truncation=True, return_tensors='pt', max_length=30)
        processedA = self.process_encode(encodeA)
        encodeB = self.tokenizer(contextB, padding="max_length", truncation=True, return_tensors='pt', max_length=30)
        processedB = self.process_encode(encodeB)
        prob = row['conditionalProb']
        return processedA, processedB, prob, movieA, movieB

    def __getitem__(self, idx):
        """Return ``(encode_A, encode_B, prob, movA, movB)`` with ``prob`` as a float tensor."""
        encode_A, encode_B, prob, movA, movB = self.generate_train_instance_id(idx)
        prob = torch.tensor(prob, dtype=torch.float)
        if self.args.cuda:
            prob = prob.cuda()
        return encode_A, encode_B, prob, movA, movB

class LabelClassfnExp(object):
    """Train and evaluate ``BubbleEmbedBase`` on pairwise conditional probabilities.

    The model maps each movie's text to a center and a radius (a "bubble").
    Training fits the overlap score of ``condition_score_cached`` to the target
    ``conditionalProb`` of each (movieA, movieB) pair. Evaluation groups test
    pairs by movieA and reports the mean KL divergence, Pearson and Spearman
    correlation across movies.
    """

    def __init__(self, args):
        self.args = args
        self.tokenizer = self.__load_tokenizer__()
        self.train_loader, self.train_set = self.load_data(self.args, self.tokenizer, "train")
        self.test_loader, self.test_set = self.load_data(self.args, self.tokenizer, "test")
        self.model = BubbleEmbedBase(args)
        self.optimizer_pretrain, self.optimizer_projection = self._select_optimizer()
        self._set_device()
        self._set_seed(self.args.seed)
        self.setting = self.args
        # Run identifier embedded in checkpoint file names.
        self.exp_setting = "_".join(
            str(value)
            for value in (
                self.args.dataset,
                self.args.expID,
                self.args.epochs,
                self.args.embed_size,
                self.args.batch_size,
                self.args.lr,
                self.args.phi,
                self.args.regularwt,
                self.args.probwt,
                self.args.seed,
                self.args.version,
            )
        )

        # Loss functions
        self.intersection_loss = nn.MSELoss()
        self.regular_loss = nn.MSELoss()
        self.prob_loss = nn.MSELoss()
        self.bubble_size_loss = nn.MSELoss()

        # Additional parameters
        self.num_dimensions = self.args.embed_size
        # Volume of the unit ball in embed_size dimensions: pi^(d/2) / Gamma(d/2 + 1).
        self.volume_factor = (math.pi ** (args.embed_size / 2)) / math.gamma((args.embed_size / 2) + 1)

    def load_data(self, args, tokenizer, mode):
        """Build the dataset and a shuffled DataLoader for ``mode``.

        Args:
            mode: "train" (uses the training dataset) or "test" (uses the test dataset).

        Returns:
            ``(dataloader, dataset)``.

        Raises:
            ValueError: if ``mode`` is neither "train" nor "test".
        """
        data_dir = "../data/ml-latest/processed"
        if mode == "train":
            dataset_cls = DataTrain
        elif mode == "test":
            dataset_cls = DataTest
        else:
            raise ValueError(f"Unknown data mode {mode!r}; expected 'train' or 'test'")
        data_path = os.path.join(data_dir, f'{mode}_samples_{self.args.minratingnum}.csv')
        desc_file = os.path.join(data_dir, f'movie_contexts_{self.args.minratingnum}.json')
        dataset = dataset_cls(args, tokenizer, data_path, desc_file)

        dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
        return dataloader, dataset

    def __load_tokenizer__(self):
        """Load the ``bert-base-uncased`` tokenizer."""
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        print("Tokenizer Loaded!")
        return tokenizer

    def _select_optimizer(self):
        """Create one optimizer for ``pre_train*`` parameters and one for ``projection*``.

        Returns:
            ``(optimizer_pretrain, optimizer_projection)``.

        Raises:
            ValueError: if ``args.optim`` is neither "adam" nor "adamw".
        """
        pre_train_parameters = [
            {
                "params": [
                    p
                    for n, p in self.model.named_parameters()
                    if n.startswith("pre_train")
                ],
                "weight_decay": 0.0,
            },
        ]
        projection_parameters = [
            {
                "params": [
                    p
                    for n, p in self.model.named_parameters()
                    if n.startswith("projection")
                ],
                "weight_decay": 0.0,
            },
        ]

        if self.args.optim == "adam":
            optimizer_pretrain = optim.Adam(pre_train_parameters, lr=self.args.lr)
            optimizer_projection = optim.Adam(
                projection_parameters, lr=self.args.lr_projection
            )
        elif self.args.optim == "adamw":
            optimizer_pretrain = optim.AdamW(
                pre_train_parameters, lr=self.args.lr, eps=self.args.eps
            )
            optimizer_projection = optim.AdamW(
                projection_parameters, lr=self.args.lr_projection, eps=self.args.eps
            )
        else:
            # Previously fell through to an UnboundLocalError on the return line.
            raise ValueError(f"Unsupported optimizer {self.args.optim!r}; expected 'adam' or 'adamw'")

        return optimizer_pretrain, optimizer_projection

    def _set_device(self):
        """Move the model to the current CUDA device when ``args.cuda`` is set."""
        if self.args.cuda:
            self.model = self.model.cuda()

    def _map_location(self):
        """``map_location`` for ``torch.load``: keep stored devices on CUDA runs, remap to CPU otherwise."""
        return None if self.args.cuda else "cpu"

    def center_distance(self, center1, center2):
        """Euclidean (L2) distance between centers along the last dimension."""
        return torch.linalg.norm(center1 - center2, 2, -1)

    def bubble_volume(self, delta, temperature=0.1):
        """Return ``volume_factor * delta ** d``, zeroed where ``delta <= 0``.

        ``temperature`` is accepted but unused. This method is not called
        elsewhere in this file.
        """
        valid_mask = (delta > 0).float()
        volume = self.volume_factor * (torch.pow(delta, self.num_dimensions))
        return (volume * valid_mask)

    def bubble_regularization(self, delta):
        """MSE penalty pushing radii below ``args.phi`` up to ``args.phi``.

        Entries with ``delta >= phi`` are masked to 0 in both operands, so they
        add zero error. They still count in the mean's denominator.
        """
        zeros = torch.zeros_like(delta)
        ones = torch.ones_like(delta)
        min_radius = torch.ones_like(delta) * self.args.phi

        small_bubble_mask = torch.where(delta < self.args.phi, ones, zeros)

        regular_loss = self.bubble_size_loss(
            torch.mul(delta, small_bubble_mask),
            torch.mul(min_radius, small_bubble_mask)
        )

        return regular_loss

    def radial_intersection_cached(self, delta1, delta2, dist_center):
        """Radius-like overlap of two bubbles.

        The result is half the overlap depth ``(r1 + r2 - dist) / 2`` when the
        bubbles intersect, or 0 when they do not, capped at the smaller radius.
        A 1-D ``dist_center`` is given a trailing axis first; presumably the radii
        are (batch, 1) — TODO confirm against BubbleEmbedBase.
        """
        sum_radius = delta1 + delta2
        if dist_center.ndim == 1:
            dist_center = dist_center.unsqueeze(1)
        mask = (dist_center < sum_radius).float()
        intersection_radius = mask * ((sum_radius - dist_center) / 2)
        intersection_radius = torch.min(intersection_radius, torch.min(delta1, delta2))
        return intersection_radius

    def condition_score_cached(self, radius_A, radius_B, dist_center):
        """Score for P(A | B): ``(overlap radius / radius_B) ** d``.

        Always returns at least a 1-D tensor. A bare ``squeeze()`` turned a
        batch of one into a 0-d tensor, which broke per-item indexing in
        ``predict`` and shape matching in the loss.
        """
        inter_delta = self.radial_intersection_cached(
            radius_A, radius_B, dist_center
        )
        mask = (inter_delta > 0).float()
        masked_inter_delta = inter_delta * mask
        # Conditioned on B
        score_pre = masked_inter_delta / radius_B
        scores = torch.pow(score_pre, self.num_dimensions)
        return torch.atleast_1d(scores.squeeze())

    def cond_prob_loss_cached(self, radius_A, radius_B, dist_center, targets):
        """MSE between the clamped conditional score and the target probabilities."""
        score = self.condition_score_cached(radius_A, radius_B, dist_center)
        # Keep scores strictly inside (0, 1).
        score = score.clamp(1e-7, 1 - 1e-7)
        loss = self.intersection_loss(score, targets)
        return loss

    def compute_loss(self, encode_A, encode_B, targets):
        """Weighted sum of the conditional-probability loss and the radius regularizer."""
        center_A, radius_A = self.model(encode_A)
        center_B, radius_B = self.model(encode_B)

        c_dist = self.center_distance(center_A, center_B)
        # Regularization Loss
        regular_loss = self.bubble_regularization(radius_B)
        regular_loss += self.bubble_regularization(radius_A)

        # Calculating the conditional probability
        cond_prob_loss = self.cond_prob_loss_cached(radius_A, radius_B, c_dist, targets)

        loss = self.args.probwt * cond_prob_loss + self.args.regularwt * regular_loss
        return loss

    def _set_seed(self, seed):
        """Seed python, numpy and torch RNGs and make cuDNN deterministic."""
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if self.args.cuda:
            torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True

    def train_one_step(self, it, encode_A, encode_B, targets):
        """Run one optimization step on a batch and return the loss tensor."""
        self.model.train()
        self.optimizer_pretrain.zero_grad()
        self.optimizer_projection.zero_grad()

        if self.args.cuda and not isinstance(targets, torch.Tensor):
            targets = torch.tensor(targets, dtype=torch.float).cuda()
        elif not isinstance(targets, torch.Tensor):
            targets = torch.tensor(targets, dtype=torch.float)

        loss = self.compute_loss(encode_A, encode_B, targets)
        loss.backward()
        self.optimizer_pretrain.step()
        self.optimizer_projection.step()
        return loss

    def train(self, checkpoint=None, save_path=None):
        """Train for ``args.epochs`` epochs, evaluating after each one.

        The best model (lower-or-equal KL and a non-decreasing Pearson or
        Spearman score) is saved under ``save_path``, which defaults to
        ``../result/<dataset>/model``. The latest epoch's weights are kept under
        ``../result/<dataset>/train``, and the previous epoch's file is deleted.
        """
        self._set_seed(self.args.seed)
        time_tracker = []

        best_KL = 1e8
        best_Pearson = 0
        best_Spearman = 0

        if checkpoint:
            self.model.load_state_dict(torch.load(checkpoint, map_location=self._map_location()))
        if save_path is None:
            save_path = os.path.join("../result", self.args.dataset, "model")
        # Per-epoch checkpoints always go here, so create the directory
        # whether or not the caller supplied save_path.
        train_path = os.path.join("../result", self.args.dataset, "train")
        os.makedirs(save_path, exist_ok=True)
        os.makedirs(train_path, exist_ok=True)

        for epoch in tqdm(range(self.args.epochs)):
            train_loss = []

            epoch_time = time.time()
            print(f"Epoch {epoch+1}/{self.args.epochs}")
            for it, (encode_A, encode_B, targets) in tqdm(enumerate(self.train_loader), total=len(self.train_loader)):
                loss = self.train_one_step(it, encode_A, encode_B, targets)
                train_loss.append(loss.item())

            train_loss = np.average(train_loss)
            test_metrics = self.predict()

            if test_metrics["KL"] <= best_KL:
                if (test_metrics["Pearson"] >= best_Pearson) or (test_metrics["Spearman"] >= best_Spearman):
                    best_KL = test_metrics["KL"]
                    best_Pearson = test_metrics["Pearson"]
                    best_Spearman = test_metrics["Spearman"]
                    torch.save(self.model.state_dict(), os.path.join(save_path, f"exp_model_{self.exp_setting}.checkpoint"))

            time_tracker.append(time.time() - epoch_time)
            print(
                "Epoch: {:04d}".format(epoch + 1),
                " train_loss:{:.05f}".format(train_loss),
                "KL:{:.05f}".format(test_metrics["KL"]),
                " Pearson:{:.05f}".format(test_metrics["Pearson"]),
                " Spearman:{:.05f}".format(test_metrics["Spearman"]),
                " epoch_time:{:.01f}s".format(time.time() - epoch_time),
                " remain_time:{:.01f}s".format(np.mean(time_tracker) * (self.args.epochs - (1 + epoch))),
            )

            torch.save(self.model.state_dict(), os.path.join(train_path, "exp_model_" + self.exp_setting + "_" + str(epoch) + ".checkpoint"))
            if epoch:
                os.remove(os.path.join(train_path, "exp_model_" + self.exp_setting + "_" + str(epoch - 1) + ".checkpoint"))

    def metrics(self, pred, target, key=None):
        """Compare predicted and true probabilities over the same items.

        Both vectors are clipped to [1e-15, 1 - 1e-15] and renormalized to sum
        to 1. The method then returns KL(target || pred), the Pearson
        correlation and the Spearman rank correlation. ``key`` is unused.
        """
        # `import scipy` alone does not load the `stats` submodule on older
        # SciPy releases; import it explicitly.
        import scipy.stats

        pred = np.array(pred)
        target = np.array(target)

        # Clip values so that log computations don't encounter zeros.
        pred = np.clip(pred, 1e-15, 1 - 1e-15)
        target = np.clip(target, 1e-15, 1 - 1e-15)

        # Normalize each vector to sum to 1.
        pred = pred / np.sum(pred)
        target = target / np.sum(target)

        kl_div = np.sum(target * np.log(target / pred))
        pearson_corr = np.corrcoef(pred, target)[0, 1]
        spearman_corr = scipy.stats.spearmanr(pred, target)[0]

        return {
            "KL": kl_div,
            "Pearson": pearson_corr,
            "Spearman": spearman_corr
        }

    def predict(self, tag=None, load_model_path=None):
        """Evaluate on the test loader and return the averaged metrics.

        Predictions are grouped by movieA. Within each group, the self pair
        (movieB == movieA) is excluded, and metrics are computed between the
        predicted and true conditional probabilities, then averaged over groups.
        With ``tag == "test"``, the saved best model is loaded first and the
        results are appended to ``../results/<dataset>/res_<version>.json``.
        """
        print("Predicting...")
        if tag == "test":
            model_path = load_model_path if load_model_path else f"../result/{self.args.dataset}/model/exp_model_{self.exp_setting}.checkpoint"
            self.model.load_state_dict(torch.load(model_path, map_location=self._map_location()))

        self.model.eval()
        pred_prob_distribution = collections.defaultdict(list)
        true_prob_distribution = collections.defaultdict(list)
        with torch.no_grad():
            for _, (encode_A, encode_B, targets, movA, movB) in tqdm(enumerate(self.test_loader), total=len(self.test_loader)):
                center_A, radius_A = self.model(encode_A)
                center_B, radius_B = self.model(encode_B)
                if self.args.cuda and not isinstance(targets, torch.Tensor):
                    targets = torch.tensor(targets, dtype=torch.float).cuda()

                c_dist = self.center_distance(center_A, center_B)
                conditional_prob = self.condition_score_cached(radius_A, radius_B, c_dist)

                # Group by movie A and tag each score with ITS OWN movie B.
                # Storing the whole batch `movB` here broke the self-pair filter below.
                for i, (mov_a, mov_b) in enumerate(zip(movA, movB)):
                    pred_prob_distribution[mov_a].append((mov_b, conditional_prob[i].item()))
                    true_prob_distribution[mov_a].append((mov_b, targets[i].item()))

            KL = []
            Pearson = []
            Spearman = []
            for key in pred_prob_distribution:
                # Identical keys and a stable sort keep pred/true entries aligned.
                pred_prob_distribution[key].sort(key=lambda x: x[0], reverse=True)
                true_prob_distribution[key].sort(key=lambda x: x[0], reverse=True)
                # Exclude the pair where movie B is the key movie itself.
                pred_prob = [x[1] for x in pred_prob_distribution[key] if x[0] != key]
                true_prob = [x[1] for x in true_prob_distribution[key] if x[0] != key]
                metrics = self.metrics(pred_prob, true_prob)

                KL.append(metrics["KL"])
                Pearson.append(metrics["Pearson"])
                Spearman.append(metrics["Spearman"])
            avg_KL = np.mean(KL)
            avg_Pearson = np.mean(Pearson)
            avg_Spearman = np.mean(Spearman)

        test_metrics = {
            "KL": avg_KL,
            "Pearson": avg_Pearson,
            "Spearman": avg_Spearman
        }
        if tag == "test":
            print("Test Metrics:")
            print("Average KL Divergence: ", avg_KL)
            print("Average Pearson Correlation: ", avg_Pearson)
            print("Average Spearman Rank Correlation: ", avg_Spearman)

            results_dir = f'../results/{self.args.dataset}'
            # Nothing else in this file creates ../results/<dataset>; without
            # this the open() below raises FileNotFoundError.
            os.makedirs(results_dir, exist_ok=True)
            with open(f'{results_dir}/res_{self.args.version}.json', 'a+') as f:
                expt_details = {
                    "Arguments": vars(self.args),
                    "Test Metrics": test_metrics
                }
                json.dump(expt_details, f, indent=4)

        return test_metrics

def preprocess_ml(args, indir, outdir=None):
    """Build the cached train/test sample files from the raw MovieLens CSVs.

    Keeps ratings strictly above 4. Movies with more than ``args.minratingnum``
    such ratings are shuffled with ``args.seed`` and split 90:10 into train and
    test movies. Pairwise conditional probabilities are then written as
    ``(movieA, movieB, P(movieA | movieB))`` rows.

    Files written to ``outdir``: ``filtered_ratings.csv``,
    ``movie_contexts_<n>.json``, ``processed_data.json``,
    ``train_samples_<n>.csv`` and ``test_samples_<n>.csv``.

    Args:
        args: namespace providing ``minratingnum`` and ``seed``.
        indir: directory containing ``ratings.csv`` and ``movies.csv``.
        outdir: output directory. Defaults to ``<indir>/processed`` and is
            created if missing.

    Returns:
        pandas Index of the movie ids that passed the rating-count filter.
    """
    if outdir is None:
        # os.path.join(None, ...) used to raise TypeError for the default.
        outdir = os.path.join(indir, 'processed')
    os.makedirs(outdir, exist_ok=True)

    ratings = pd.read_csv(os.path.join(indir, 'ratings.csv'), header=0, index_col=None)
    filtered_ratings = ratings[ratings['rating'] > 4]
    movie_counts = filtered_ratings['movieId'].value_counts()
    popular_movies = movie_counts[movie_counts > args.minratingnum].index
    result = filtered_ratings[filtered_ratings['movieId'].isin(popular_movies)]
    print(f"Number of movies with more than {args.minratingnum} ratings: ", len(result['movieId'].unique()))
    # One row per movie holding the list of users who rated it.
    result = result.groupby('movieId').agg({'userId': lambda x: list(x)}).reset_index()
    result.to_csv(os.path.join(outdir, 'filtered_ratings.csv'), index=False)

    movies = pd.read_csv(os.path.join(indir, 'movies.csv'), header=0, index_col=None)
    filtered_movies = movies[movies['movieId'].isin(popular_movies)]
    movie_list = filtered_movies['movieId'].tolist()
    # Shuffle with a fixed seed, then split 90:10 into train and test movies.
    np.random.seed(args.seed)
    np.random.shuffle(movie_list)
    split = int(0.9 * len(movie_list))
    train_movies = movie_list[:split]
    test_movies = movie_list[split:]

    # id_concepts: movieId -> title
    # concepts_id: title -> movieId
    # id_contexts: movieId -> "title genre1 genre2 ..." (the text that gets tokenized)
    id_concepts = {}
    concepts_id = {}
    id_contexts = {}
    for _, row in filtered_movies.iterrows():
        id_concepts[row['movieId']] = row['title']
        concepts_id[row['title']] = row['movieId']
        genres = ' '.join(row['genres'].split('|'))
        # Separate title and genres; they used to be concatenated with no space.
        id_contexts[row['movieId']] = f"{row['title']} {genres}"

    with open(os.path.join(outdir, f'movie_contexts_{args.minratingnum}.json'), 'w') as f:
        json.dump(id_contexts, f)

    processed_data = {
        'id_concepts': id_concepts,
        'concepts_id': concepts_id,
        'movie_list': movie_list,
        'train_movies': train_movies,
        'test_movies': test_movies
    }
    with open(os.path.join(outdir, 'processed_data.json'), 'w') as f:
        json.dump(processed_data, f)

    # Every sample is (X, Y, P(X | Y)).
    train_samples = []
    test_samples = []
    user_sets_by_movie = filtered_ratings.groupby('movieId')['userId'].apply(set).to_dict()

    # Train x train: both directions of every unordered pair.
    for i in tqdm(range(len(train_movies))):
        movieA = train_movies[i]
        set_users_A = user_sets_by_movie[movieA]
        for movieB in train_movies[i + 1:]:
            set_users_B = user_sets_by_movie[movieB]
            prob_A_given_B, prob_B_given_A = calc_conditional_probabilities(set_users_A, set_users_B)
            train_samples.append((movieB, movieA, prob_B_given_A))
            train_samples.append((movieA, movieB, prob_A_given_B))

    # Train x test: P(test | train) goes to training, P(train | test) to testing.
    # NOTE(review): test movies therefore appear in training samples; confirm this is intended.
    for movieA in tqdm(train_movies):
        set_users_A = user_sets_by_movie[movieA]
        for movieB in test_movies:
            set_users_B = user_sets_by_movie[movieB]
            prob_A_given_B, prob_B_given_A = calc_conditional_probabilities(set_users_A, set_users_B)
            train_samples.append((movieB, movieA, prob_B_given_A))
            test_samples.append((movieA, movieB, prob_A_given_B))

    train_samples_path = os.path.join(outdir, f'train_samples_{args.minratingnum}.csv')
    pd.DataFrame(train_samples, columns=['movieA', 'movieB', 'conditionalProb']).to_csv(train_samples_path, index=False)
    print("Train samples saved to: ", train_samples_path)

    # Test x test: both directions of every pair, plus each self pair once,
    # so every test movie has its full conditional distribution.
    for i in tqdm(range(len(test_movies))):
        movieA = test_movies[i]
        set_users_A = user_sets_by_movie[movieA]
        for movieB in test_movies[i:]:
            set_users_B = user_sets_by_movie[movieB]
            prob_A_given_B, prob_B_given_A = calc_conditional_probabilities(set_users_A, set_users_B)
            test_samples.append((movieB, movieA, prob_B_given_A))
            if movieA != movieB:
                test_samples.append((movieA, movieB, prob_A_given_B))

    test_samples_path = os.path.join(outdir, f'test_samples_{args.minratingnum}.csv')
    pd.DataFrame(test_samples, columns=['movieA', 'movieB', 'conditionalProb']).to_csv(test_samples_path, index=False)
    print("Test samples saved to: ", test_samples_path)

    return popular_movies

def calc_conditional_probabilities(setA, setB):
    """Estimate ``(P(A|B), P(B|A))`` from two sets of user ids.

    P(A|B) is the fraction of B's users who also appear in A, i.e.
    ``|A ∩ B| / |B|``; P(B|A) is ``|A ∩ B| / |A|``. An empty set raises
    ZeroDivisionError.
    """
    shared_count = len(setA.intersection(setB))
    return shared_count / len(setB), shared_count / len(setA)

def set_seed(seed):
    """Seed Python, NumPy and PyTorch (and every CUDA device, if present).

    Also forces cuDNN into deterministic mode with autotuning disabled, so
    repeated runs produce the same results.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def main():
    """Command-line entry point.

    Parses the arguments, builds the processed data files if any are missing,
    then trains a ``LabelClassfnExp`` and evaluates the saved best model.
    """

    def str2bool(value):
        """Parse a CLI boolean such as 'True', 'false', '1' or 'no'.

        ``type=bool`` treated every non-empty string (including "False") as True.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("true", "t", "yes", "y", "1"):
            return True
        if lowered in ("false", "f", "no", "n", "0"):
            return False
        raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser()

    parser.add_argument("--dataset", type=str, default="movielens", help="dataset")
    ## Model parameters
    parser.add_argument("--pre_train", type=str, default="bert", help="Pre_trained model")
    parser.add_argument(
        "--hidden", type=int, default=64, help="dimension of hidden layers in MLP"
    )
    parser.add_argument(
        "--embed_size", type=int, default=6, help="dimension of bubble embeddings"
    )
    parser.add_argument("--dropout", type=float, default=0.05, help="dropout")
    # phi is compared against bubble radii in the regularizer.
    parser.add_argument("--phi", type=float, default=0.03, help="minimum radius of bubble")
    parser.add_argument("--probwt", type=float, default=1.0, help="weight of prob loss")
    parser.add_argument(
        "--regularwt", type=float, default=1.0, help="weight of regularization loss"
    )
    ## Training hyper-parameters
    parser.add_argument("--expID", type=int, default=0, help="-th of experiments")
    parser.add_argument("--epochs", type=int, default=100, help="training epochs")
    parser.add_argument("--batch_size", type=int, default=512, help="training batch size")
    parser.add_argument(
        "--lr", type=float, default=2e-5, help="learning rate for pre-trained model"
    )
    parser.add_argument(
        "--lr_projection",
        type=float,
        default=1e-3,
        help="learning rate for projection layers",
    )
    parser.add_argument("--eps", type=float, default=1e-8, help="adamw_epsilon")
    parser.add_argument("--optim", type=str, default="adamw", help="Optimizer")
    parser.add_argument("--version", type=str, default="spherex", help="version of the model")

    ## Others
    parser.add_argument("--cuda", type=str2bool, default=True, help="use cuda for training")
    parser.add_argument("--gpu_id", type=int, default=0, help="which gpu")
    parser.add_argument("--seed", type=int, default=42, help="Seed for random generator")
    parser.add_argument("--minratingnum", type=int, default=4000, help="Minimum number of ratings for a movie to be considered")

    args = parser.parse_args()
    # Only use CUDA when it is both requested and available.
    args.cuda = bool(torch.cuda.is_available() and args.cuda)
    if args.cuda:
        torch.cuda.set_device(args.gpu_id)
    start_time = time.time()
    print("Start time at : ")
    print_local_time()

    print("Arguments: ", args)
    indir = "../data/ml-latest"
    outdir = "../data/ml-latest/processed"

    set_seed(args.seed)

    os.makedirs(outdir, exist_ok=True)
    os.makedirs(f"../result/{args.dataset}", exist_ok=True)

    required_files = [
        os.path.join(outdir, f'train_samples_{args.minratingnum}.csv'),
        os.path.join(outdir, f'test_samples_{args.minratingnum}.csv'),
        os.path.join(outdir, f'movie_contexts_{args.minratingnum}.json'),
    ]
    if not all(os.path.exists(path) for path in required_files):
        print("Preprocessing the data...")
        preprocess_ml(args, indir, outdir)

    exp = LabelClassfnExp(args)
    exp.train()
    exp.predict(tag="test")

    print("Time used :{:.01f}s".format(time.time() - start_time))
    print("End time at : ")
    print_local_time()
    print("************END***************")


if __name__ == "__main__":
    main()
    